{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.utils"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model utilities"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">Utility functions used to build PyTorch timeseries models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from copy import deepcopy\n",
    "from fastai.layers import flatten_model, params, apply_init\n",
    "from fastai.learner import Learner\n",
    "from fastai.data.transforms import get_c\n",
    "from fastai.tabular.model import *\n",
    "from fastai.callback.schedule import *\n",
    "from fastai.vision.models.xresnet import *\n",
    "from tsai.models.layers import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def apply_idxs(o, idxs):\n",
    "    \"Function to apply indices to zarr, dask and numpy arrays\"\n",
    "    if is_zarr(o): return o.oindex[idxs]\n",
    "    elif is_dask(o): return o[idxs].compute()\n",
    "    else: return o[idxs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def SeqTokenizer(c_in, embed_dim, token_size=60, norm=False):\n",
    "    \"Generates non-overlapping tokens from sub-sequences within a sequence by applying a sliding window\"\n",
    "    return ConvBlock(c_in, embed_dim, token_size, stride=token_size, padding=0, act=None,\n",
    "                     norm='Batch' if norm else None, bias=norm is None)\n",
    "\n",
    "SeqEmbed = SeqTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def get_embed_size(n_cat, rule='log2'):\n",
    "    if rule == 'log2':\n",
    "        return int(np.ceil(np.log2(n_cat)))\n",
    "    elif rule == 'thumb':\n",
    "        return min(600, round(1.6 * n_cat**0.56)) # fastai's"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(get_embed_size(35), 6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def get_layers(model, cond=noop, full=True):\n",
    "    if isinstance(model, Learner): model=model.model\n",
    "    if full: return [m for m in flatten_model(model) if any([c(m) for c in L(cond)])]\n",
    "    else: return [m for m in model if any([c(m) for c in L(cond)])]\n",
    "\n",
    "def is_layer(*args):\n",
    "    def _is_layer(l, cond=args):\n",
    "        return isinstance(l, cond)\n",
    "    return partial(_is_layer, cond=args)\n",
    "\n",
    "def is_linear(l):\n",
    "    return isinstance(l, nn.Linear)\n",
    "\n",
    "def is_bn(l):\n",
    "    types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)\n",
    "    return isinstance(l, types)\n",
    "\n",
    "def is_conv_linear(l):\n",
    "    types = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)\n",
    "    return isinstance(l, types)\n",
    "\n",
    "def is_affine_layer(l):\n",
    "    return has_bias(l) or has_weight(l)\n",
    "\n",
    "def is_conv(l):\n",
    "    types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)\n",
    "    return isinstance(l, types)\n",
    "\n",
    "def has_bias(l):\n",
    "    return (hasattr(l, 'bias') and l.bias is not None)\n",
    "\n",
    "def has_weight(l):\n",
    "    return (hasattr(l, 'weight'))\n",
    "\n",
    "def has_weight_or_bias(l):\n",
    "    return any((has_weight(l), has_bias(l)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def check_bias(m, cond=noop, verbose=False):\n",
    "    mean, std = [], []\n",
    "    for i,l in enumerate(get_layers(m, cond=cond)):\n",
    "        if hasattr(l, 'bias') and l.bias is not None:\n",
    "            b = l.bias.data\n",
    "            mean.append(b.mean())\n",
    "            std.append(b.std())\n",
    "            pv(f'{i:3} {l.__class__.__name__:15} shape: {str(list(b.shape)):15}  mean: {b.mean():+6.3f}  std: {b.std():+6.3f}', verbose)\n",
    "    return np.array(mean), np.array(std)\n",
    "\n",
    "def check_weight(m, cond=noop, verbose=False):\n",
    "    mean, std = [], []\n",
    "    for i,l in enumerate(get_layers(m, cond=cond)):\n",
    "        if hasattr(l, 'weight') and l.weight is not None:\n",
    "            w = l.weight.data\n",
    "            mean.append(w.mean())\n",
    "            std.append(w.std())\n",
    "            pv(f'{i:3} {l.__class__.__name__:15} shape: {str(list(w.shape)):15}  mean: {w.mean():+6.3f}  std: {w.std():+6.3f}', verbose)\n",
    "    return np.array(mean), np.array(std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def get_nf(m):\n",
    "    \"Get nf from model's first linear layer in head\"\n",
    "    return get_layers(m[-1], is_linear)[0].in_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def ts_splitter(m):\n",
    "    \"Split of a model between body and head\"\n",
    "    return L(m.backbone, m.head).map(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def transfer_weights(model, weights_path:Path, device:torch.device=None, exclude_head:bool=True):\n",
    "    \"\"\"Utility function that allows to easily transfer weights between models.\n",
    "    Taken from the great self-supervised repository created by Kerem Turgutlu.\n",
    "    https://github.com/KeremTurgutlu/self_supervised/blob/d87ebd9b4961c7da0efd6073c42782bbc61aaa2e/self_supervised/utils.py\"\"\"\n",
    "\n",
    "    device = ifnone(device, default_device())\n",
    "    state_dict = model.state_dict()\n",
    "    new_state_dict = torch.load(weights_path, map_location=device)\n",
    "    matched_layers = 0\n",
    "    unmatched_layers = []\n",
    "    for name, param in state_dict.items():\n",
    "        if exclude_head and 'head' in name: continue\n",
    "        if name in new_state_dict:\n",
    "            matched_layers += 1\n",
    "            input_param = new_state_dict[name]\n",
    "            if input_param.shape == param.shape: param.copy_(input_param)\n",
    "            else: unmatched_layers.append(name)\n",
    "        else:\n",
    "            unmatched_layers.append(name)\n",
    "            pass # these are weights that weren't in the original model, such as a new head\n",
    "    if matched_layers == 0: raise Exception(\"No shared weight names were found between the models\")\n",
    "    else:\n",
    "        if len(unmatched_layers) > 0:\n",
    "            print(f'check unmatched_layers: {unmatched_layers}')\n",
    "        else:\n",
    "            print(f\"weights from {weights_path} successfully transferred!\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def build_ts_model(arch, c_in=None, c_out=None, seq_len=None, d=None, dls=None, device=None, verbose=False,\n",
    "                   s_cat_idxs=None, s_cat_embeddings=None, s_cat_embedding_dims=None, s_cont_idxs=None,\n",
    "                   o_cat_idxs=None, o_cat_embeddings=None, o_cat_embedding_dims=None, o_cont_idxs=None,\n",
    "                   patch_len=None, patch_stride=None, fusion_layers=128, fusion_act='relu', fusion_dropout=0., fusion_use_bn=True,\n",
    "                   pretrained=False, weights_path=None, exclude_head=True, cut=-1, init=None, arch_config={}, **kwargs):\n",
    "\n",
    "    device = ifnone(device, default_device())\n",
    "    if dls is not None:\n",
    "        c_in = ifnone(c_in, dls.vars)\n",
    "        c_out = ifnone(c_out, dls.c)\n",
    "        seq_len = ifnone(seq_len, dls.len)\n",
    "        d = ifnone(d, dls.d)\n",
    "\n",
    "    if s_cat_idxs or s_cat_embeddings or s_cat_embedding_dims or s_cont_idxs or o_cat_idxs or o_cat_embeddings or o_cat_embedding_dims or o_cont_idxs:\n",
    "        from tsai.models.multimodal import MultInputWrapper\n",
    "        model = MultInputWrapper(arch, c_in=c_in, c_out=c_out, seq_len=seq_len, d=d,\n",
    "                                 s_cat_idxs=s_cat_idxs, s_cat_embeddings=s_cat_embeddings, s_cat_embedding_dims=s_cat_embedding_dims, s_cont_idxs=s_cont_idxs,\n",
    "                                 o_cat_idxs=o_cat_idxs, o_cat_embeddings=o_cat_embeddings, o_cat_embedding_dims=o_cat_embedding_dims, o_cont_idxs=o_cont_idxs,\n",
    "                                 patch_len=patch_len, patch_stride=patch_stride,\n",
    "                                 fusion_layers=fusion_layers, fusion_act=fusion_act, fusion_dropout=fusion_dropout, fusion_use_bn=fusion_use_bn,**arch_config,\n",
    "                                 **kwargs)\n",
    "    else:\n",
    "        if d and arch.__name__ not in [\"PatchTST\", \"PatchTSTPlus\", 'TransformerRNNPlus', 'TransformerLSTMPlus', 'TransformerGRUPlus',\n",
    "        \"RNN_FCNPlus\", \"LSTM_FCNPlus\", \"GRU_FCNPlus\", \"MRNN_FCNPlus\", \"MLSTM_FCNPlus\", \"MGRU_FCNPlus\",\n",
    "        \"RNNAttentionPlus\", \"LSTMAttentionPlus\", \"GRUAttentionPlus\", \"ConvTran\", \"ConvTranPlus\"]:\n",
    "            if 'custom_head' not in kwargs.keys():\n",
    "                if \"rocket\" in arch.__name__.lower() or 'hydra' in arch.__name__.lower():\n",
    "                    kwargs['custom_head'] = partial(rocket_nd_head, d=d)\n",
    "                elif \"xresnet1d\" in arch.__name__.lower():\n",
    "                    kwargs[\"custom_head\"] = partial(xresnet1d_nd_head, d=d)\n",
    "                else:\n",
    "                    kwargs['custom_head'] = partial(lin_nd_head, d=d)\n",
    "            elif not isinstance(kwargs['custom_head'], nn.Module):\n",
    "                kwargs['custom_head'] = partial(kwargs['custom_head'], d=d)\n",
    "        if 'ltsf_' in arch.__name__.lower() or 'patchtst' in arch.__name__.lower():\n",
    "            pv(f'arch: {arch.__name__}(c_in={c_in} c_out={c_out} seq_len={seq_len} pred_dim={d} arch_config={arch_config}, kwargs={kwargs})', verbose)\n",
    "            model = (arch(c_in=c_in, c_out=c_out, seq_len=seq_len, pred_dim=d, **arch_config, **kwargs)).to(device=device)\n",
    "        elif arch.__name__ in ['TransformerRNNPlus', 'TransformerLSTMPlus', 'TransformerGRUPlus', \"RNN_FCNPlus\", \"LSTM_FCNPlus\", \"GRU_FCNPlus\", \"MRNN_FCNPlus\",\n",
    "        \"MLSTM_FCNPlus\", \"MGRU_FCNPlus\", \"RNNAttentionPlus\", \"LSTMAttentionPlus\", \"GRUAttentionPlus\", \"ConvTran\", \"ConvTranPlus\", 'mWDNPlus']:\n",
    "            pv(f'arch: {arch.__name__}(c_in={c_in} c_out={c_out} seq_len={seq_len} d={d} arch_config={arch_config}, kwargs={kwargs})', verbose)\n",
    "            model = (arch(c_in=c_in, c_out=c_out, seq_len=seq_len, d=d, **arch_config, **kwargs)).to(device=device)\n",
    "        elif sum([1 for v in ['RNN_FCN', 'LSTM_FCN', 'RNNPlus', 'LSTMPlus', 'GRUPlus', 'InceptionTime', 'TSiT', 'Sequencer', 'XceptionTimePlus',\n",
    "                            'GRU_FCN', 'OmniScaleCNN', 'mWDN', 'TST', 'XCM', 'MLP', 'MiniRocket', 'InceptionRocket', 'ResNetPlus',\n",
    "                            'RNNAttention', 'LSTMAttention', 'GRUAttention', 'MultiRocket', 'MultiRocketPlus', 'Hydra', 'HydraPlus',\n",
    "                            'HydraMultiRocket', 'HydraMultiRocketPlus']\n",
    "                if v in arch.__name__]):\n",
    "            pv(f'arch: {arch.__name__}(c_in={c_in} c_out={c_out} seq_len={seq_len} arch_config={arch_config} kwargs={kwargs})', verbose)\n",
    "            model = arch(c_in, c_out, seq_len=seq_len, **arch_config, **kwargs).to(device=device)\n",
    "        elif 'xresnet' in arch.__name__ and not '1d' in arch.__name__:\n",
    "            pv(f'arch: {arch.__name__}(c_in={c_in} n_out={c_out} arch_config={arch_config} kwargs={kwargs})', verbose)\n",
    "            model = (arch(c_in=c_in, n_out=c_out, **arch_config, **kwargs)).to(device=device)\n",
    "        elif 'xresnet1d' in arch.__name__.lower():\n",
    "            pv(f'arch: {arch.__name__}(c_in={c_in} c_out={c_out} seq_len={seq_len} arch_config={arch_config} kwargs={kwargs})', verbose)\n",
    "            model = (arch(c_in=c_in, c_out=c_out, seq_len=seq_len, **arch_config, **kwargs)).to(device=device)\n",
    "        elif 'minirockethead' in arch.__name__.lower():\n",
    "            pv(f'arch: {arch.__name__}(c_in={c_in} seq_len={seq_len} arch_config={arch_config} kwargs={kwargs})', verbose)\n",
    "            model = (arch(c_in, c_out, seq_len=1, **arch_config, **kwargs)).to(device=device)\n",
    "        elif 'rocket' in arch.__name__.lower():\n",
    "            pv(f'arch: {arch.__name__}(c_in={c_in} seq_len={seq_len} arch_config={arch_config} kwargs={kwargs})', verbose)\n",
    "            model = (arch(c_in=c_in, seq_len=seq_len, **arch_config, **kwargs)).to(device=device)\n",
    "        else:\n",
    "            pv(f'arch: {arch.__name__}(c_in={c_in} c_out={c_out} arch_config={arch_config} kwargs={kwargs})', verbose)\n",
    "            model = arch(c_in, c_out, **arch_config, **kwargs).to(device=device)\n",
    "\n",
    "    try:\n",
    "        model[0], model[1]\n",
    "        subscriptable = True\n",
    "    except:\n",
    "        subscriptable = False\n",
    "    if hasattr(model, \"head_nf\"):  head_nf = model.head_nf\n",
    "    else:\n",
    "        try: head_nf = get_nf(model)\n",
    "        except: head_nf = None\n",
    "\n",
    "    if not subscriptable and 'Plus' in arch.__name__:\n",
    "        model = nn.Sequential(*model.children())\n",
    "        model.backbone = model[:cut]\n",
    "        model.head = model[cut:]\n",
    "\n",
    "    if pretrained and not ('xresnet' in arch.__name__ and not '1d' in arch.__name__):\n",
    "        assert weights_path is not None, \"you need to pass a valid weights_path to use a pre-trained model\"\n",
    "        transfer_weights(model, weights_path, exclude_head=exclude_head, device=device)\n",
    "\n",
    "    if init is not None:\n",
    "        apply_init(model[1] if pretrained else model, init)\n",
    "\n",
    "    setattr(model, \"head_nf\", head_nf)\n",
    "    setattr(model, \"__name__\", arch.__name__)\n",
    "\n",
    "    return model\n",
    "\n",
    "build_model = build_ts_model\n",
    "create_model = build_ts_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.data.core import get_ts_dls, TSClassification\n",
    "from tsai.models.TSiTPlus import TSiTPlus\n",
    "from fastai.losses import CrossEntropyLossFlat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "arch: TSiTPlus(c_in=3 c_out=2 seq_len=128 arch_config={} kwargs={'custom_head': functools.partial(<class 'tsai.models.layers.lin_nd_head'>, d=3)})\n",
      "torch.Size([13, 3, 2])\n",
      "TensorBase(0.8796, grad_fn=<AliasBackward0>)\n"
     ]
    }
   ],
   "source": [
    "X = np.random.rand(16, 3, 128)\n",
    "y = np.random.randint(0, 2, (16, 3))\n",
    "tfms = [None, [TSClassification()]]\n",
    "dls = get_ts_dls(X, y, splits=RandomSplitter()(range_of(X)), tfms=tfms)\n",
    "model = build_ts_model(TSiTPlus, dls=dls, pretrained=False, verbose=True)\n",
    "xb, yb = dls.one_batch()\n",
    "output = model(xb)\n",
    "print(output.shape)\n",
    "loss = CrossEntropyLossFlat()(output, yb)\n",
    "print(loss)\n",
    "assert output.shape == (dls.bs, dls.d, dls.c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def count_parameters(model, trainable=True):\n",
    "    if trainable: return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    else: return sum(p.numel() for p in model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# @delegates(XResNet.__init__)\n",
    "def build_tsimage_model(arch, c_in=None, c_out=None, dls=None, pretrained=False, device=None, verbose=False, init=None, arch_config={}, **kwargs):\n",
    "    device = ifnone(device, default_device())\n",
    "    if dls is not None:\n",
    "        c_in = ifnone(c_in, dls.vars)\n",
    "        c_out = ifnone(c_out, dls.c)\n",
    "\n",
    "    model = arch(pretrained=pretrained, c_in=c_in, n_out=c_out, **arch_config, **kwargs).to(device=device)\n",
    "    setattr(model, \"__name__\", arch.__name__)\n",
    "    if init is not None:\n",
    "        apply_init(model[1] if pretrained else model, init)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# @delegates(TabularModel.__init__)\n",
    "def build_tabular_model(arch, dls, layers=None, emb_szs=None, n_out=None, y_range=None, device=None, arch_config={}, **kwargs):\n",
    "    if device is None: device = default_device()\n",
    "    if layers is None: layers = [200,100]\n",
    "    emb_szs = get_emb_sz(dls.train_ds, {} if emb_szs is None else emb_szs)\n",
    "    if n_out is None: n_out = get_c(dls)\n",
    "    assert n_out, \"`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`\"\n",
    "    if y_range is None and 'y_range' in kwargs: y_range = kwargs.pop('y_range')\n",
    "    model = arch(emb_szs, len(dls.cont_names), n_out, layers, y_range=y_range, **arch_config, **kwargs).to(device=device)\n",
    "\n",
    "    if hasattr(model, \"head_nf\"):  head_nf = model.head_nf\n",
    "    else: head_nf = get_nf(model)\n",
    "    setattr(model, \"__name__\", arch.__name__)\n",
    "    if head_nf is not None: setattr(model, \"head_nf\", head_nf)\n",
    "    return model\n",
    "\n",
    "create_tabular_model = build_tabular_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.data.external import get_UCR_data\n",
    "from tsai.data.core import TSCategorize, get_ts_dls\n",
    "from tsai.data.preprocessing import TSStandardize\n",
    "from tsai.models.InceptionTime import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, splits = get_UCR_data('NATOPS', split_data=False)\n",
    "tfms = [None, TSCategorize()]\n",
    "batch_tfms = TSStandardize()\n",
    "dls = get_ts_dls(X, y, splits, tfms=tfms, batch_tfms=batch_tfms)\n",
    "model = build_ts_model(InceptionTime, dls=dls)\n",
    "test_eq(count_parameters(model), 460038)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def get_clones(module, N):\n",
    "    return nn.ModuleList([deepcopy(module) for i in range(N)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ModuleList(\n",
       "  (0-2): 3 x Conv1d(3, 4, kernel_size=(3,), stride=(1,))\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = nn.Conv1d(3,4,3)\n",
    "get_clones(m, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def split_model(m): return m.backbone, m.head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "@torch.no_grad()\n",
    "def output_size_calculator(mod, c_in, seq_len=None):\n",
    "    assert isinstance(mod, nn.Module)\n",
    "    return_q_len = True\n",
    "    if seq_len is None:\n",
    "        seq_len = 50\n",
    "        return_q_len = False\n",
    "    try:\n",
    "        params_0 = list(mod.parameters())[0]\n",
    "        xb = torch.rand(1, c_in, seq_len, device=params_0.device, dtype=params_0.dtype)\n",
    "    except:\n",
    "        xb = torch.rand(1, c_in, seq_len)\n",
    "    training = mod.training\n",
    "    mod.eval()\n",
    "    output_shape = tuple(mod.to(xb.device)(xb).shape)\n",
    "    if len(output_shape) == 2:\n",
    "        c_out, q_len = output_shape[1], None\n",
    "    else:\n",
    "        c_out, q_len = output_shape[1:]\n",
    "    mod.training = training\n",
    "    if return_q_len:\n",
    "        return c_out, q_len\n",
    "    else:\n",
    "        return c_out, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c_in = 3\n",
    "seq_len = 30\n",
    "m = nn.Conv1d(3, 12, kernel_size=3, stride=2)\n",
    "new_c_in, new_seq_len = output_size_calculator(m, c_in, seq_len)\n",
    "test_eq((new_c_in, new_seq_len), (12, 14))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def change_model_head(model, custom_head, **kwargs):\n",
    "    r\"\"\"Replaces a model's head by a custom head as long as the model has a head, head_nf, c_out and seq_len attributes\"\"\"\n",
    "    model.head = custom_head(model.head_nf, model.c_out, model.seq_len, **kwargs)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def naive_forecaster(o, split, horizon=1):\n",
    "    if is_listy(horizon):\n",
    "        _f = []\n",
    "        for h in horizon:\n",
    "            _f.append(o[np.asarray(split)-h])\n",
    "        return np.stack(_f)\n",
    "    return o[np.asarray(split) - horizon]\n",
    "\n",
    "def true_forecaster(o, split, horizon=1):\n",
    "    o_true = o[split]\n",
    "    if is_listy(horizon):\n",
    "        o_true = o_true[np.newaxis].repeat(len(horizon), 0)\n",
    "    return o_true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([0.99029138, 1.68463991, 2.21744589, 2.65448222, 2.85159354,\n",
       "        3.26171729, 3.67986707, 4.04343956, 4.3077458 , 4.44585435,\n",
       "        4.76876866, 4.85844441, 4.93256093, 5.52415845, 6.10704489,\n",
       "        6.74848957, 7.31920741, 8.20198208, 8.78954283, 9.0402    ]),\n",
       " array([4.44585435, 4.76876866, 4.85844441, 4.93256093, 5.52415845,\n",
       "        6.10704489, 6.74848957, 7.31920741, 8.20198208, 8.78954283]),\n",
       " array([4.76876866, 4.85844441, 4.93256093, 5.52415845, 6.10704489,\n",
       "        6.74848957, 7.31920741, 8.20198208, 8.78954283, 9.0402    ]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = np.random.rand(20).cumsum()\n",
    "split = np.arange(10, 20)\n",
    "a, naive_forecaster(a, split, 1), true_forecaster(a, split, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/030_models.utils.ipynb saved at 2025-01-22 18:23:18\n",
      "Correct notebook to script conversion! 😃\n",
      "Wednesday 22/01/25 18:23:21 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcOwn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
